import astropy.io.fits as fits
import math
import numpy as np
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
import astropy.units as u
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
plt.rcParams.update({'font.size': 20})

def rms(x, axis=None):
    """Root-mean-square of ``x``, reduced over ``axis`` (every element when None)."""
    mean_square = np.mean(x**2, axis=axis)
    return np.sqrt(mean_square)

def get_flux(imagename, coord, radius, exclude_inner=False):
    """Measure the flux and noise of a 2-D FITS image inside a circular aperture.

    Parameters
    ----------
    imagename : str
        Path to a FITS file whose primary HDU holds a 2-D image with
        ``CDELT2``, ``NAXIS1``, ``NAXIS2``, ``BMAJ`` and ``BMIN`` header keys.
        (The indexing below assumes a 2-D array -- TODO confirm for cubes.)
    coord : astropy.coordinates.SkyCoord
        Aperture centre; a scalar or a length-1 array SkyCoord.
    radius : float
        Aperture radius in units of ``CDELT2 * 3600`` -- i.e. arcsec, assuming
        ``CDELT2`` is in degrees as is usual for celestial FITS axes.
    exclude_inner : bool
        If True, drop the central 0.1 (same units as ``radius``) disc from the
        aperture, leaving an annulus.

    Returns
    -------
    tuple
        ``(flux, e_flux, flux_3sig, e_flux_3sig, rms_value, flux_beam,
        e_flux_beam, nbeam_per_spectrum, npix_beam)``.

        * ``flux``: aperture sum divided by the beam area in pixels.
        * ``e_flux``: its uncertainty, rms * sqrt(beams in aperture).
        * ``flux_3sig`` / ``e_flux_3sig``: the same, restricted to aperture
          pixels brighter than 3 * rms.
        * ``rms_value``: rms of all pixels outside the aperture.
        * ``flux_beam``: the raw aperture pixel sum (not beam-normalised).
        * ``e_flux_beam``: computed with the same expression as ``e_flux``.
        * ``nbeam_per_spectrum``: aperture area in beams.
        * ``npix_beam``: Gaussian beam area in pixels.
    """
    # Read what we need and close the file; copy the pixels so they stay valid
    # after closing (astropy may memory-map the data).
    with fits.open(imagename) as hdul:
        header = hdul[0].header
        data = hdul[0].data.copy()

    wcs = WCS(header)
    px, py = wcs.world_to_pixel(coord)
    # np.ravel handles both a scalar SkyCoord and a length-1 array SkyCoord.
    x = np.ravel(px)[0]
    # Previously this read the x pixel twice, so the aperture sat on the diagonal.
    y = np.ravel(py)[0]

    arcsec_per_px = header['CDELT2'] * 3600.0
    radius_px = radius / arcsec_per_px
    inner_radius_px = 0.1 / arcsec_per_px
    nx = header['NAXIS1']
    ny = header['NAXIS2']

    # create_circular_mask2 returns (ny, nx) arrays, i.e. already in numpy's
    # data[y, x] order, so no transpose is needed now that the centre is (x, y).
    mask, inverse_mask = create_circular_mask2(nx, ny, center=[x, y], radius=radius_px)
    aperture_area_px = math.pi * radius_px**2
    if exclude_inner:
        inner_mask, _ = create_circular_mask2(nx, ny, center=[x, y], radius=inner_radius_px)
        mask = mask - inner_mask
        # Keep the noise estimate consistent with the annulus actually summed.
        aperture_area_px -= math.pi * inner_radius_px**2

    # Gaussian beam solid angle, pi * BMAJ * BMIN / (4 ln 2), converted to pixels.
    npix_beam = (math.pi * header['BMAJ'] * header['BMIN']
                 / (4.0 * math.log(2.0)) / header['CDELT2']**2)
    nbeam_per_spectrum = aperture_area_px / npix_beam

    flux = np.sum(data * mask / npix_beam)
    flux_beam = np.sum(data * mask)

    # Noise from every pixel outside the aperture.
    rms_value = rms(data[inverse_mask == 1])
    e_flux = rms_value * nbeam_per_spectrum**0.5
    # NOTE(review): identical to e_flux although flux_beam is not divided by
    # npix_beam -- confirm whether a different scaling was intended.
    e_flux_beam = rms_value * nbeam_per_spectrum**0.5

    # Flux from aperture pixels above 3 sigma only.
    bright = (data > 3.0 * rms_value) & (mask == 1)
    flux_3sig = np.sum(data[bright] / npix_beam)
    e_flux_3sig = rms_value * (np.count_nonzero(bright) / npix_beam)**0.5

    return (flux, e_flux, flux_3sig, e_flux_3sig, rms_value,
            flux_beam, e_flux_beam, nbeam_per_spectrum, npix_beam)


def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean (h, w) array that is True on a closed disc.

    ``center`` is ``(x, y)`` in pixel indices and defaults to the grid middle
    ``(int(w/2), int(h/2))``.  ``radius`` defaults to the largest radius that
    keeps the disc inside the grid.  Pixels with distance <= radius are True.
    """
    if center is None:
        center = (int(w/2), int(h/2))
    cx = center[0]
    cy = center[1]

    if radius is None:
        radius = min(cx, cy, w-cx, h-cy)

    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx)**2 + (rows - cy)**2)
    return distance <= radius

def create_circular_mask2(nx, ny, center=None, radius=None):
    """Return a float disc mask and its complement on an ``nx`` x ``ny`` pixel grid.

    Parameters
    ----------
    nx, ny : int
        Grid size along x (columns) and y (rows).
    center : sequence of two floats, optional
        Disc centre ``(x, y)`` in 0-based pixel coordinates.  Defaults to
        ``(int(nx/2), int(ny/2))``, matching ``create_circular_mask``.
    radius : float, optional
        Disc radius in pixels; pixels strictly closer than ``radius`` are
        inside.  Defaults to the largest radius that keeps the disc on the grid,
        matching ``create_circular_mask``.

    Returns
    -------
    mask, inverse_mask : ndarray of float64, shape (ny, nx)
        ``mask`` is 1.0 where distance < radius and 0.0 elsewhere.
        ``inverse_mask`` is 1.0 where distance >= radius and 0.0 elsewhere.
        Both are indexed ``[y, x]`` like a numpy image.
    """
    if center is None:
        center = (int(nx/2), int(ny/2))
    if radius is None:
        radius = min(center[0], center[1], nx - center[0], ny - center[1])

    # meshgrid's default 'xy' indexing yields (ny, nx) arrays.  The previous
    # per-element loop indexed them as [x, y], which raised IndexError on any
    # non-square grid.  A vectorised threshold covers every pixel.
    xx, yy = np.meshgrid(np.arange(nx, dtype=float), np.arange(ny, dtype=float))
    dist = np.sqrt((xx - center[0])**2 + (yy - center[1])**2)
    mask = (dist < radius).astype(float)
    inverse_mask = (dist >= radius).astype(float)
    return mask, inverse_mask




# Images to measure, keyed by a short line label.  Each entry holds the FITS
# path and, for some entries, a display 'title'.
flux_dict = {
    'HDO_241': {
        'filename': 'V883_Ori_SB_HDO-241.558GHz_robust_2.0.image_M0_5.05_7.05kms.fits',
        'title': 'HDO 241 GHz',
    },
    'HDO_225': {
        'filename': 'V883_Ori_SB_HDO-225.535GHz_robust_2.0.image_M0_5.05_7.05kms.fits',
        'title': 'HDO 225 GHz',
    },
    # Alternative image: 'V883_Ori_SB_H218O-203.GHz_robust_2.0.image.fits'
    'H218O_203': {
        'filename': 'V883_Ori_SB_H218O-203.GHz-0.4kms_robust_2.0.image_M0_5.05_7.05kms.fits',
    },
}

# ICRS position on which every aperture is centred.
radecstring = '05h38m18.100454s -07d02m25.99340s'
coord = SkyCoord([radecstring], frame='icrs', unit=(u.hourangle, u.deg))



# Names under which each element of get_flux's return tuple is stored, in order.
FLUX_FIELDS = ('flux', 'e_flux', 'flux_3sig', 'e_flux_3sig', 'rms',
               'flux_beam', 'flux_beam_error', 'nbeams', 'npix')

for key, entry in flux_dict.items():
    print(key)
    # Wider aperture for keys containing 'CN' or 'C18O', 0.4 otherwise
    # (get_flux converts the radius with CDELT2 * 3600, i.e. arcsec).
    radius = 1.0 if ('CN' in key or 'C18O' in key) else 0.4
    results = get_flux(entry['filename'], coord, radius, exclude_inner=False)
    entry.update(zip(FLUX_FIELDS, results))
    if key == 'HDO_241':
        # NOTE(review): hard-coded flux correction for this line only; the
        # origin of 0.8 * 0.0232 * 0.4 is not documented here -- confirm.
        entry['flux'] = entry['flux'] - 0.8*0.0232*0.4

# Close the handle deterministically instead of leaking it.
with open("flux_dict.pickle", "wb") as pickle_file:
    pickle.dump(flux_dict, pickle_file)

for key, entry in flux_dict.items():
    print(key, entry['flux'], entry['e_flux'], entry['flux_3sig'],
          entry['e_flux_3sig'], entry['rms'])




